/* Original code has been submitted by Liu Liu.
   ----------------------------------------------------------------------------------
   * Spill-Tree for Approximate KNN Search
   * Author: Liu Liu
   * mailto: liuliu.1987+opencv@gmail.com
   * Refer to Paper:
   * An Investigation of Practical Approximate Nearest Neighbor Algorithms
   * cvMergeSpillTree TBD
   *
   * Redistribution and use in source and binary forms, with or
   * without modification, are permitted provided that the following
   * conditions are met:
   * 	Redistributions of source code must retain the above
   * 	copyright notice, this list of conditions and the following
   * 	disclaimer.
   * 	Redistributions in binary form must reproduce the above
   * 	copyright notice, this list of conditions and the following
   * 	disclaimer in the documentation and/or other materials
   * 	provided with the distribution.
   * 	The name of Contributor may not be used to endorse or
   * 	promote products derived from this software without
   * 	specific prior written permission.
   *
   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
   * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
   * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
   * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
   * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
   * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
   * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
   * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
   * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
   * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
   * OF SUCH DAMAGE.
   */

#include "precomp.hpp"
#include "_featuretree.h"

// One node of the spill tree. The same struct serves three roles in this file:
//  - a data point: an element of a doubly linked list (lc = previous, rc = next);
//  - a leaf bucket (leaf == true): lc / rc are the head / tail of its point list
//    and cc is the list length;
//  - an internal node (leaf == false): lc / rc are the left / right subtrees.
struct CvSpillTreeNode {
	bool leaf; // true for data points and for buckets that are searched exhaustively
	bool spill; // true when the two children overlap (points inside [lb, ub] are in both); enables defeatist search
	CvSpillTreeNode* lc; // left child (<), or list head / previous element (see above)
	CvSpillTreeNode* rc; // right child (>), or list tail / next element (see above)
	int cc; // number of points in the list hanging off this node at build time
	CvMat* u; // unit projection vector (internal nodes only)
	CvMat* center; // data point: header over its input row; internal node: mean of its points
	int i; // index of the data point's row in the input matrix
	double r; // largest distance from center to any point under this node
	double ub; // upper bound of the overlap region on the projection axis (mp + buffer)
	double lb; // lower bound of the overlap region on the projection axis (mp - buffer)
	double mp; // mean of the projected points (split value); on data points it is reused during search to hold the distance to the query
	double p; // projection of this point's center onto the parent's u
};

// The tree itself plus the per-query scratch state.
struct CvSpillTree {
	CvSpillTreeNode* root; // top of the tree
	CvMat** refmat; // one matrix header per input row (shared with the data points' center); released with the tree
	bool* cache; // per-point "already visited" flags, cleared before every query
	int total; // number of data points (rows of the input matrix)
	int naive; // a node holding at most this many points becomes a leaf and is searched exhaustively
	int type; // type field copied from the input matrix
	double rho; // a node becomes a spill (overlapping) node only when both overlapping sides hold fewer than round(cc * rho) points
	double tau; // the overlapping buffer ratio: buffer half-width = projected span * tau / 2
};

// Walk "total" elements of the linked list starting at "list" (following rc)
// and return the one whose center is farthest from node->center.
// Returns NULL when total is not positive.
static inline CvSpillTreeNode*
icvFarthestNode( CvSpillTreeNode* node,
				 CvSpillTreeNode* list,
				 int total ) {
	CvSpillTreeNode* best = NULL;
	double best_dist = -1.;
	for ( CvSpillTreeNode* cur = list; total > 0; --total, cur = cur->rc ) {
		const double cur_dist = cvNorm( node->center, cur->center );
		if ( cur_dist > best_dist ) {
			best = cur;
			best_dist = cur_dist;
		}
	}
	return best;
}

// Allocate a new node and fill it with a shallow copy of "node":
// every field, including the center / u pointers and the list links, is copied as is.
static inline CvSpillTreeNode*
icvCloneSpillTreeNode( CvSpillTreeNode* node ) {
	CvSpillTreeNode* copy = (CvSpillTreeNode*)cvAlloc( sizeof(*copy) );
	*copy = *node;
	return copy;
}

// Push "append" onto the tail of the doubly linked list owned by "node"
// (node->lc is the head, node->rc the tail) and bump the element count.
static inline void
icvAppendSpillTreeNode( CvSpillTreeNode* node,
						CvSpillTreeNode* append ) {
	// the appended element always becomes the new tail
	append->rc = NULL;
	if ( node->lc != NULL ) {
		CvSpillTreeNode* tail = node->rc;
		append->lc = tail;
		tail->rc = append;
	} else {
		// empty list: the element is the head as well
		append->lc = NULL;
		node->lc = append;
	}
	node->rc = append;
	++node->cc;
}

// Address of element number "step" in the data buffer of matrix x, as void*.
// "step" is an offset counted in elements (float for CV_32F, double for CV_64F),
// not the byte stride stored in CvMat::step. Evaluates to NULL for any other depth.
#define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))

// Recursively turn "node" into a subtree.
//
// On entry node owns a doubly linked list of data points: node->lc is the head,
// node->rc the tail and node->cc the number of elements. If the list is short
// enough (cc <= tr->naive) the node is marked as a leaf and keeps the list.
// Otherwise the points are projected onto an axis through two far apart points,
// split around the mean projection into two freshly allocated children, and the
// function recurses into both children.
//
//   tr   - tree parameters (naive, rho, tau, type); not modified here
//   d    - dimensionality of the points (number of columns)
//   node - node to initialise; its lc / rc are replaced by the new children
//          unless it ends up as a leaf
static void
icvDFSInitSpillTreeNode( const CvSpillTree* tr,
						 const int d,
						 CvSpillTreeNode* node ) {
	if ( node->cc <= tr->naive ) {
		// already get to a leaf, terminate the recursion.
		node->leaf = true;
		node->spill = false;
		return;
	}

	// random select a node, then find a farthest node from this one, then find a farthest from that one...
	// to approximate the farthest node-pair
	// NOTE: the generator is function-static with a fixed seed, so its state is
	// shared by every call of this function in the process.
	static CvRNG rng_state = cvRNG(0xdeadbeef);
	int rn = cvRandInt( &rng_state ) % node->cc;
	CvSpillTreeNode* lnode = NULL;
	CvSpillTreeNode* rnode = node->lc;
	// advance to the rn-th element of the list
	for ( int i = 0; i < rn; i++ ) {
		rnode = rnode->rc;
	}
	lnode = icvFarthestNode( rnode, node->lc, node->cc );
	rnode = icvFarthestNode( lnode, node->lc, node->cc );

	// u is the projection vector
	// (normalized difference of the two far apart points)
	node->u = cvCreateMat( 1, d, tr->type );
	cvSub( lnode->center, rnode->center, node->u );
	cvNormalize( node->u, node->u );

	// find the center of node in hyperspace
	// (arithmetic mean of all points in the list)
	node->center = cvCreateMat( 1, d, tr->type );
	cvZero( node->center );
	CvSpillTreeNode* it = node->lc;
	for ( int i = 0; i < node->cc; i++ ) {
		cvAdd( it->center, node->center, node->center );
		it = it->rc;
	}
	cvConvertScale( node->center, node->center, 1. / node->cc );

	// project every node to "u", and find the mean point "mp"
	// the same pass records the radius r: the largest distance from the center to a point
	it = node->lc;
	node->r = -1.;
	node->mp = 0;
	for ( int i = 0; i < node->cc; i++ ) {
		node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
		double norm = cvNorm( node->center, it->center );
		if ( norm > node->r ) {
			node->r = norm;
		}
		it = it->rc;
	}
	node->mp = node->mp / node->cc;

	// overlapping buffer and upper bound, lower bound
	// the buffer half-width is tau/2 of the distance between the two extreme projections
	double ob = (lnode->p - rnode->p) * tr->tau * .5;
	node->ub = node->mp + ob;
	node->lb = node->mp - ob;
	// sl / sr: sizes the left / right child would have with overlap;
	// l / r: sizes they would have with a plain split at mp
	int sl = 0, l = 0;
	int sr = 0, r = 0;
	it = node->lc;
	for ( int i = 0; i < node->cc; i++ ) {
		if ( it->p <= node->ub ) {
			sl++;
		}
		if ( it->p >= node->lb ) {
			sr++;
		}
		if ( it->p < node->mp ) {
			l++;
		} else {
			r++;
		}
		it = it->rc;
	}
	// precision problem, return the node as it is.
	// (one side would be empty, so the node is kept as a leaf and the matrices
	// created above are released again)
	if (( l == 0 ) || ( r == 0 )) {
		cvReleaseMat( &(node->u) );
		cvReleaseMat( &(node->center) );
		node->leaf = true;
		node->spill = false;
		return;
	}
	// both children start as zeroed nodes with an empty point list
	CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
	memset(lc, 0, sizeof(CvSpillTreeNode));
	CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
	memset(rc, 0, sizeof(CvSpillTreeNode));
	lc->lc = lc->rc = rc->lc = rc->rc = NULL;
	lc->cc = rc->cc = 0;
	// if an overlapping child would hold at least this many points, do not overlap
	int undo = cvRound(node->cc * tr->rho);
	if (( sl >= undo ) || ( sr >= undo )) {
		// it is not a spill point (defeatist search disabled)
		// plain split at mp; "next" is saved first because appending rewrites it->rc
		it = node->lc;
		for ( int i = 0; i < node->cc; i++ ) {
			CvSpillTreeNode* next = it->rc;
			if ( it->p < node->mp ) {
				icvAppendSpillTreeNode( lc, it );
			} else {
				icvAppendSpillTreeNode( rc, it );
			}
			it = next;
		}
		node->spill = false;
	} else {
		// a spill point
		// points inside [lb, ub] go to both children: the original element to the
		// left one and a shallow clone (same center and index) to the right one
		it = node->lc;
		for ( int i = 0; i < node->cc; i++ ) {
			CvSpillTreeNode* next = it->rc;
			if ( it->p < node->lb ) {
				icvAppendSpillTreeNode( lc, it );
			} else if ( it->p > node->ub ) {
				icvAppendSpillTreeNode( rc, it );
			} else {
				CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
				icvAppendSpillTreeNode( lc, it );
				icvAppendSpillTreeNode( rc, cit );
			}
			it = next;
		}
		node->spill = true;
	}
	// from here on lc / rc of this node are subtrees, no longer list head / tail
	node->lc = lc;
	node->rc = rc;

	// recursion process
	icvDFSInitSpillTreeNode( tr, d, node->lc );
	icvDFSInitSpillTreeNode( tr, d, node->rc );
}

// Build a spill tree over the rows of raw_data.
//
//   raw_data - one point per row; the tree keeps headers pointing INTO this
//              buffer, so the data must stay alive for the tree's lifetime
//   naive    - nodes with at most this many points are searched exhaustively
//   rho, tau - spill parameters, stored in the tree and used while splitting
//
// Returns a tree that must be released with icvReleaseSpillTree.
static CvSpillTree*
icvCreateSpillTree( const CvMat* raw_data,
					const int naive,
					const double rho,
					const double tau ) {
	int n = raw_data->rows;
	int d = raw_data->cols;

	CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
	tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
	memset(tr->root, 0, sizeof(CvSpillTreeNode));
	tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*) * n );
	tr->cache = (bool*)cvAlloc( sizeof(bool) * n );
	tr->total = n;
	tr->naive = naive;
	tr->rho = rho;
	tr->tau = tau;
	tr->type = raw_data->type;

	// Start of the data buffer; NULL for depths other than CV_32F / CV_64F,
	// which keeps the previous behaviour of creating headers without data.
	uchar* base = (uchar*)_dispatch_mat_ptr(raw_data, 0);

	// tie a link-list to the root node: one element per input row.
	// Building every element in the same loop means an empty input simply
	// yields an empty list instead of touching refmat[0] and row 0.
	CvSpillTreeNode* tail = NULL;
	for ( int i = 0; i < n; i++ ) {
		CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
		memset(newnode, 0, sizeof(CvSpillTreeNode));
		newnode->center = cvCreateMatHeader( 1, d, tr->type );
		// Rows are raw_data->step BYTES apart; using i * d elements is only
		// correct for continuous matrices and reads the wrong memory otherwise.
		cvSetData( newnode->center,
				   base != NULL ? base + (size_t)i * raw_data->step : NULL,
				   raw_data->step );
		tr->refmat[i] = newnode->center;
		newnode->i = i;
		newnode->leaf = true;
		newnode->lc = tail;
		newnode->rc = NULL;
		if ( tail != NULL ) {
			tail->rc = newnode;
		} else {
			tr->root->lc = newnode;
		}
		tail = newnode;
	}
	tr->root->rc = tail;
	tr->root->cc = n;
	icvDFSInitSpillTreeNode( tr, d, tr->root );
	return tr;
}

// Sift heap[i] down inside a max-heap of k slots ordered by node->mp.
// An empty (NULL) slot counts as larger than any node, so a real node always
// sinks below empty slots. Does nothing when heap[i] itself is empty.
static void
icvSpillTreeNodeHeapify( CvSpillTreeNode** heap,
						 int i,
						 const int k ) {
	if ( heap[i] == NULL ) {
		return;
	}
	for ( ;; ) {
		const int right = (i + 1) << 1;
		const int left = right - 1;
		int top = i;
		if ( left < k && heap[left] == NULL ) {
			top = left;
		} else if ( right < k && heap[right] == NULL ) {
			top = right;
		} else {
			// both existing children hold nodes: pick the largest of the three
			if ( left < k && heap[left]->mp > heap[i]->mp ) {
				top = left;
			}
			if ( right < k && heap[right]->mp > heap[top]->mp ) {
				top = right;
			}
		}
		if ( top == i ) {
			break;
		}
		CvSpillTreeNode* moved = heap[top];
		heap[top] = heap[i];
		heap[i] = moved;
		i = top;
	}
}

// Depth-first search of the subtree under "node" for the neighbours of "desc".
//
//   tr   - tree; tr->cache marks points that were already measured
//   heap - k-slot max-heap (by distance, kept in node->mp) of the best points so far
//   es   - running count of heap insertions; the search stops once it reaches emax
//   desc - query point (1 x d)
//   k    - number of neighbours wanted
//   emax - insertion budget, ignored when not positive
static void
icvSpillTreeDFSearch( CvSpillTree* tr,
					  CvSpillTreeNode* node,
					  CvSpillTreeNode** heap,
					  int* es,
					  const CvMat* desc,
					  const int k,
					  const int emax ) {
	if ( emax > 0 && *es >= emax ) {
		return;
	}

	// defeatist search: on spill nodes follow a single child while the query
	// lies clearly outside the overlap and that child holds at least k points
	// (otherwise better neighbours could be skipped)
	double proj = 0;
	while ( node->spill ) {
		if ( !node->leaf ) {
			proj = cvDotProduct( node->u, desc );
		}
		CvSpillTreeNode* chosen;
		if ( proj < node->lb && node->lc->cc >= k ) {
			chosen = node->lc;
		} else if ( proj > node->ub && node->rc->cc >= k ) {
			chosen = node->rc;
		} else {
			break;
		}
		if ( chosen == NULL ) {
			return;
		}
		node = chosen;
	}

	if ( node->leaf ) {
		// a leaf, naive search over every point not measured yet
		CvSpillTreeNode* pt = node->lc;
		for ( int left = node->cc; left > 0; --left, pt = pt->rc ) {
			if ( tr->cache[pt->i] ) {
				continue;
			}
			pt->mp = cvNorm( pt->center, desc );
			tr->cache[pt->i] = true;
			// keep the point only if the heap has a free slot on top or it beats the current worst
			if ( heap[0] != NULL && !( pt->mp < heap[0]->mp ) ) {
				continue;
			}
			heap[0] = pt;
			icvSpillTreeNodeHeapify( heap, 0, k );
			++(*es);
		}
		return;
	}

	// impossible case, skip: even the closest point of this ball is farther
	// than the worst neighbour kept so far
	const double center_dist = cvNorm( node->center, desc );
	if ( heap[0] != NULL && center_dist - node->r > heap[0]->mp ) {
		return;
	}

	// guided dfs: visit the side the query projects into first
	proj = cvDotProduct( node->u, desc );
	const bool left_first = proj < node->mp;
	CvSpillTreeNode* first = left_first ? node->lc : node->rc;
	CvSpillTreeNode* second = left_first ? node->rc : node->lc;
	icvSpillTreeDFSearch( tr, first, heap, es, desc, k, emax );
	icvSpillTreeDFSearch( tr, second, heap, es, desc, k, emax );
}

// Find up to k approximate nearest neighbours for every row of desc.
//
//   tr      - tree to search; its visit cache is reset for every query row
//   desc    - query points, one per row, same element type as the tree data
//   results - row j receives the indices of the neighbours of query j as int,
//             closest first; slots that could not be filled are set to -1
//   dist    - row j receives the matching distances as double; entries that
//             belong to unfilled slots are left untouched
//   k       - neighbours per query; nothing is done when k is not positive
//   emax    - search budget handed to icvSpillTreeDFSearch
static void
icvFindSpillTreeFeatures( CvSpillTree* tr,
						  const CvMat* desc,
						  CvMat* results,
						  CvMat* dist,
						  const int k,
						  const int emax ) {
	// Compare the element type only: the raw type field also carries flag bits
	// (e.g. continuity) that may legitimately differ between two matrices.
	assert( CV_MAT_TYPE(desc->type) == CV_MAT_TYPE(tr->type) );
	if ( k <= 0 ) {
		// the search unconditionally reads heap[0], so an empty heap must never reach it
		return;
	}
	CvSpillTreeNode** heap = (CvSpillTreeNode**)cvAlloc( k * sizeof(heap[0]) );
	for ( int j = 0; j < desc->rows; j++ ) {
		// Rows are "step" bytes apart; j * cols elements is only right for
		// continuous matrices.
		CvMat _desc = cvMat( 1, desc->cols, desc->type,
							 desc->data.ptr + (size_t)j * desc->step );
		for ( int i = 0; i < k; i++ ) {
			heap[i] = NULL;
		}
		memset( tr->cache, 0, sizeof(bool)*tr->total );
		int es = 0;
		icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax );
		// heap sort in place: afterwards heap[] is ordered by increasing distance
		CvSpillTreeNode* inp;
		for ( int i = k - 1; i > 0; i-- ) {
			CV_SWAP( heap[i], heap[0], inp );
			icvSpillTreeNodeHeapify( heap, 0, i );
		}
		int* rs = (int*)( results->data.ptr + (size_t)j * results->step );
		double* dt = (double*)( dist->data.ptr + (size_t)j * dist->step );
		for ( int i = 0; i < k; i++, rs++, dt++ ) {
			if ( heap[i] != NULL ) {
				*rs = heap[i]->i;
				*dt = heap[i]->mp;
			} else {
				*rs = -1;
			}
		}
	}
	cvFree( &heap );
}

// Free "node" and everything below it. Internal nodes release their own
// u / center matrices and recurse; leaves free the elements of their point
// list (the elements' center headers are not touched here).
static void
icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node ) {
	if ( !node->leaf ) {
		cvReleaseMat( &node->u );
		cvReleaseMat( &node->center );
		icvDFSReleaseSpillTreeNode( node->lc );
		icvDFSReleaseSpillTreeNode( node->rc );
	} else {
		CvSpillTreeNode* cur = node->lc;
		for ( int left = node->cc; left > 0; --left ) {
			// grab the successor before the element goes away
			CvSpillTreeNode* following = cur->rc;
			cvFree( &cur );
			cur = following;
		}
	}
	cvFree( &node );
}

// Destroy a tree built by icvCreateSpillTree: the per-row matrix headers,
// the header table, the visit cache, every node, and finally the tree struct.
// *tr is reset by the final cvFree.
static void
icvReleaseSpillTree( CvSpillTree** tr ) {
	CvSpillTree* tree = *tr;
	for ( int row = 0; row < tree->total; ++row ) {
		cvReleaseMat( &tree->refmat[row] );
	}
	cvFree( &tree->refmat );
	cvFree( &tree->cache );
	icvDFSReleaseSpillTreeNode( tree->root );
	cvFree( tr );
}

class CvSpillTreeWrap : public CvFeatureTree {
	CvSpillTree* tr;
public:
	CvSpillTreeWrap(const CvMat* raw_data,
					const int naive,
					const double rho,
					const double tau) {
		tr = icvCreateSpillTree(raw_data, naive, rho, tau);
	}
	~CvSpillTreeWrap() {
		icvReleaseSpillTree(&tr);
	}

	void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
		icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
	}
};

// Public factory: builds a spill tree over raw_data and returns it as a
// CvFeatureTree. The object is allocated with new.
CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
								  const int naive,
								  const double rho,
								  const double tau ) {
	CvSpillTreeWrap* wrapped = new CvSpillTreeWrap( raw_data, naive, rho, tau );
	return wrapped;
}
